from __future__ import annotations

import openllm_core, pydantic, typing as t
from openllm_core._configuration import ModelSettings

if t.TYPE_CHECKING:
  import transformers

# Markers used to assemble Dolly's instruction-following prompt format
# (see `DollyV2Config.template` below, and END_KEY for the stop token).
INSTRUCTION_KEY = '### Instruction:'
RESPONSE_KEY = '### Response:'
END_KEY = '### End'
INTRO_BLURB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.'


def get_special_token_id(tokenizer: transformers.PreTrainedTokenizer, key: str) -> int:
  """Return the single token id that *key* encodes to under *tokenizer*.

  Used to resolve special markers (e.g. ``END_KEY``) to an ``eos_token_id``.

  Raises:
    ValueError: if *key* does not encode to exactly one token (zero or
      multiple ids). The original code only guarded the multi-token case and
      would raise an opaque ``IndexError`` on an empty encoding.
  """
  token_ids = tokenizer.encode(key)
  # `!= 1` (not `> 1`) so an empty encoding also produces a clear error
  # instead of an IndexError on `token_ids[0]`.
  if len(token_ids) != 1:
    raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
  return token_ids[0]


class DollyV2Config(openllm_core.LLMConfig):
  """Databricks` Dolly is an instruction-following large language model trained on the Databricks machine learning platform that is licensed for commercial use.

  Based on pythia-12b, Dolly is trained on ~15k instruction/response fine tuning records databricks-dolly-15k
  generated by Databricks employees in capability domains from the InstructGPT paper, including brainstorming,
  classification, closed QA, generation, information extraction, open QA and summarization.

  dolly-v2-12b is not a state-of-the-art model, but does exhibit surprisingly high quality instruction
  following behavior not characteristic of the foundation model on which it is based.

  Refer to [Databricks's Dolly page](https://github.com/databrickslabs/dolly) for more information.
  """

  # Reject unknown fields; `protected_namespaces=()` lifts pydantic's default
  # restriction on `model_`-prefixed attribute names (e.g. `model_ids` below).
  model_config = pydantic.ConfigDict(extra='forbid', protected_namespaces=())

  # Static model metadata (architecture, hub ids, source URL). Excluded from
  # serialization and repr — it describes the model family, not user settings.
  metadata_config: ModelSettings = pydantic.Field(
    default={
      'timeout': 3600000,
      'url': 'https://github.com/databrickslabs/dolly',
      'architecture': 'GPTNeoXForCausalLM',
      'default_id': 'databricks/dolly-v2-3b',
      'model_ids': ['databricks/dolly-v2-3b', 'databricks/dolly-v2-7b', 'databricks/dolly-v2-12b'],
    },
    repr=False,
    exclude=True,
  )

  # NOTE: from get_special_token_id(self.tokenizer, END_KEY)
  # i.e. 50277 is the precomputed id of the '### End' marker, used as the
  # generation stop token. `model_construct` skips validation of the defaults.
  generation_config: openllm_core.GenerationConfig = pydantic.Field(
    default=openllm_core.GenerationConfig.model_construct(
      temperature=0.9, top_p=0.92, top_k=5, max_new_tokens=256, eos_token_id=50277
    )
  )

  @property
  def template(self) -> str:
    # Dolly prompt layout: intro blurb, '### Instruction:' header, the (still
    # unfilled) '{instruction}' placeholder, then the '### Response:' header.
    return '{intro}\n{instruction_key}\n{instruction}\n{response_key}\n'.format(
      intro=INTRO_BLURB, instruction_key=INSTRUCTION_KEY, instruction='{instruction}', response_key=RESPONSE_KEY
    )